import jax
import jax.numpy as jnp
import equinox as eqx
import functools

def top_k_mask(codes, k):
    """Return a 0/1 mask selecting the ``k`` largest entries along the last axis.

    The mask has the same shape and dtype as ``codes``; multiply it with
    ``codes`` to keep only the top-k activations and zero the rest. Any number
    of leading (batch) axes is supported, including none (a single vector).
    Ties are broken in favour of the lower index. If ``k`` exceeds the size of
    the last axis, every entry is selected.

    Args:
        codes: Array of activations; selection happens along its last axis.
        k: Number of entries to keep per row. Must be a non-negative Python
            int (it determines array shapes, so it cannot be a traced value).

    Returns:
        Array shaped like ``codes`` with ones at the top-k positions of each
        row and zeros elsewhere.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    n = codes.shape[-1]
    k = min(k, n)
    if k == 0:
        # Nothing to keep (k == 0 or an empty last axis).
        return jnp.zeros_like(codes)
    # Collapse all leading axes into one so the scatter below is always 2-D.
    flat = codes.reshape(-1, n)
    # lax.top_k performs a partial selection rather than sorting whole rows.
    _, indices = jax.lax.top_k(flat, k)
    rows = jnp.arange(flat.shape[0])[:, None]
    mask = jnp.zeros_like(flat).at[rows, indices].set(1.0)
    return mask.reshape(codes.shape)

class Autoencoder(eqx.Module):
    """Tied-weight autoencoder with a top-k sparsity constraint on the codes.

    A single matrix ``encoder`` of shape ``(latent_dim, input_dim)`` is used
    both to encode (``encoder @ x``) and, transposed, to decode
    (``z @ encoder``). When ``use_bias`` is true, one shared ``bias`` vector of
    length ``input_dim`` is subtracted before encoding and added back after
    decoding.
    """

    encoder: jnp.ndarray  # (latent_dim, input_dim); shared by encode and decode
    bias: jnp.ndarray  # (input_dim,) zeros at init, or None when use_bias is False
    use_bias: bool
    k: int  # number of latent activations kept by the top-k encoders

    def __init__(self, latent_dim: int, input_dim: int, k: int, use_bias: bool = True, key=None):
        """Initialise ``encoder`` with He-uniform weights and ``bias`` with zeros.

        Args:
            latent_dim: Number of latent units (rows of ``encoder``).
            input_dim: Input dimensionality (columns of ``encoder``).
            k: How many latent activations the top-k encoders keep.
            use_bias: Whether to use the shared input/output bias.
            key: JAX PRNG key used to initialise ``encoder``. Required; the
                ``None`` default is kept only for signature compatibility.

        Raises:
            ValueError: If ``key`` is None.
        """
        if key is None:
            raise ValueError("Autoencoder requires a PRNG `key` to initialise its weights")
        # fan_in is input_dim (last axis), fan_out is latent_dim (first axis),
        # matching how the matrix is applied as ``encoder @ x``.
        initializer = jax.nn.initializers.he_uniform(in_axis=-1, out_axis=-2)
        self.encoder = initializer(key, (latent_dim, input_dim), jnp.float32)
        self.bias = jnp.zeros(input_dim) if use_bias else None
        self.use_bias = use_bias
        self.k = k

    def encode(self, x):
        """Return the dense latent code ``encoder @ (x - bias)`` for one input.

        ``x`` is presumably a single vector of length ``input_dim``; batches
        go through the ``batch_*`` methods, which vmap over this one.
        """
        if self.use_bias:
            x = x - self.bias
        return self.encoder @ x

    def top_k_encode(self, x):
        """Encode one input and zero all but its ``k`` largest latent activations."""
        codes = self.encode(x)
        # Reuse the module-level helper on a batch of one so single-example
        # and batched encoding select activations in exactly the same way.
        mask = top_k_mask(codes[None, :], self.k)[0]
        return codes * mask

    def batch_top_k_encode(self, x_batch):
        """Encode each row of ``x_batch`` and keep only its top-``k`` activations."""
        codes = jax.vmap(self.encode)(x_batch)
        mask = top_k_mask(codes, self.k)
        return codes * mask

    def decode(self, z):
        """Map one latent code back to input space using the tied weights."""
        # The decoder matrix is encoder.T, so z @ decoder.T is simply z @ encoder.
        x_hat = z @ self.encoder
        if self.use_bias:
            x_hat = x_hat + self.bias
        return x_hat

    def batch_decode(self, z_batch):
        """Decode each row of ``z_batch`` independently."""
        return jax.vmap(self.decode)(z_batch)

@eqx.filter_value_and_grad
def loss_fn(model: Autoencoder, batch: jnp.ndarray):
    """Mean per-example squared reconstruction error of ``model`` on ``batch``.

    Because of the ``eqx.filter_value_and_grad`` wrapper, calling this returns
    ``(loss, grads)``, with gradients taken w.r.t. the floating-point array
    leaves of ``model``.
    """
    codes = model.batch_top_k_encode(batch)
    reconstruction = model.batch_decode(codes)
    squared_error = (batch - reconstruction) ** 2
    # Sum over features first, then average across the batch.
    per_example = jnp.sum(squared_error, axis=-1)
    return jnp.mean(per_example)

@eqx.filter_jit
def train_step(model: Autoencoder, batch: jnp.ndarray, opt_state, optimizer):
    """Run one optimisation step on ``batch``.

    Args:
        model: Current autoencoder.
        batch: Batch of inputs forwarded to ``loss_fn``.
        opt_state: Optimiser state matching ``optimizer``.
        optimizer: Presumably an optax ``GradientTransformation`` exposing
            ``update(updates, state, params)`` — confirm against the caller.

    Returns:
        ``(new_model, new_opt_state, loss)`` where ``loss`` is the loss of the
        pre-update ``model`` on ``batch``.
    """
    loss, grads = loss_fn(model, batch)
    # Pass the model's array leaves as ``params`` so transformations that need
    # the current parameter values (e.g. decoupled weight decay) also work.
    params = eqx.filter(model, eqx.is_array)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state, loss